# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
 
# recurrent_units.py
# Version 2.0
#
# Recurrent units that can be used inside a recurrent layer group.
#   To use these units, import this module in your config file:
#     import trainer.recurrent_units
# 

from paddle.trainer.config_parser import *

# Long short-term memory (LSTM) unit; can be used in a recurrent machine.
# *inputs* must be a list of Projections, for example:
#   inputs = [FullMatrixProjection("input_layer_name")],
# *para_prefix* defines the parameter names; if two LstmRecurrentUnits have
#   the same *para_prefix*, they share the same parameters.
# *out_memory* can be created by the caller and passed in if it is also
#   used outside this unit.
def LstmRecurrentUnit(name, size, 
                      active_type, state_active_type, gate_active_type, 
                      inputs, para_prefix = None, 
                      error_clipping_threshold = 0,
                      out_memory = None):
    """Declare one LSTM step built around LstmStepLayer.

    Three layers are declared, in this order:
      * "<name>_input_recurrent": a mixed layer of width ``size * 4`` fed by
        *inputs* plus a FullMatrixProjection of *out_memory*;
      * "<name>": the LstmStepLayer, fed by the mixed layer and by the
        memory of "<name>_state";
      * "<name>_state": a GetOutputLayer reading the "state" argument of
        the step layer.

    Args:
        name: name of the step layer and prefix of the helper layers.
        size: width of the step layer, the state layer and both memories.
        active_type: passed to LstmStepLayer as ``active_type``.
        state_active_type: passed to LstmStepLayer as ``active_state_type``.
        gate_active_type: passed to LstmStepLayer as ``active_gate_type``.
        inputs: list of projections for the mixed layer (it is concatenated
            with another list, so it has to be a list).
        para_prefix: prefix of the parameter names; defaults to *name*.
        error_clipping_threshold: forwarded to the mixed layer only.
        out_memory: memory projected back into the mixed layer; when None,
            ``Memory(name, size)`` is created here.
    """
    if para_prefix is None:
        para_prefix = name
    if out_memory is None:
        out_memory = Memory(name=name, size=size)

    # Names shared by several of the declarations below.
    state_layer = name + "_state"
    recurrent_layer = name + "_input_recurrent"
    recurrent_param = para_prefix + "_input_recurrent"

    state_memory = Memory(name=state_layer, size=size)

    recurrent_bias = Bias(initial_std=0,
                          parameter_name=recurrent_param + ".b")
    memory_projection = FullMatrixProjection(
        out_memory, parameter_name=recurrent_param + ".w")
    Layer(
        name=recurrent_layer,
        type="mixed",
        size=size * 4,  # (input_s, input_gate, forget_gate, output_gate)
        error_clipping_threshold=error_clipping_threshold,
        bias=recurrent_bias,
        inputs=inputs + [memory_projection],
    )

    LstmStepLayer(
        name=name,
        size=size,
        bias=Bias(parameter_name=para_prefix + "_check.b"),
        inputs=[recurrent_layer, state_memory],
        active_type=active_type,
        active_gate_type=gate_active_type,
        active_state_type=state_active_type,
    )

    # Expose the step layer's "state" argument under the name that
    # state_memory refers to.
    GetOutputLayer(
        name=state_layer,
        size=size,
        inputs=Input(name, input_layer_argument="state"),
    )

def LstmRecurrentUnitNaive(name, size, 
                           active_type, state_active_type, gate_active_type, 
                           inputs, para_prefix = None, 
                           error_clipping_threshold = 0,
                           out_memory = None):
    """Declare one LSTM step spelled out with mixed/expression layers.

    Takes the same arguments as LstmRecurrentUnit, but instead of a single
    LstmStepLayer it declares the individual pieces, in this order:
    "<name>_input_recurrent", "<name>_input_s", "<name>_input_gate",
    "<name>_forget_gate", "<name>_state", "<name>_output_gate",
    "<name>_state_atv" and finally "<name>".

    Args:
        name: name of the output layer and prefix of the helper layers.
        size: width of one slice of the ``size * 4`` wide recurrent input.
        active_type: activation of "<name>_input_s".
        state_active_type: activation of "<name>_state_atv".
        gate_active_type: activation of the input, forget and output gates.
        inputs: list of projections for the recurrent input layer (it is
            concatenated with another list, so it has to be a list).
        para_prefix: prefix of the parameter names; defaults to *name*.
        error_clipping_threshold: forwarded to the recurrent input layer
            only.
        out_memory: memory projected back into the recurrent input layer;
            when None, ``Memory(name, size)`` is created here.
    """
    if para_prefix is None:
        para_prefix = name
    if out_memory is None:
        out_memory = Memory(name=name, size=size)

    # Layer names used more than once below.
    recurrent_layer = name + "_input_recurrent"
    input_s_layer = name + "_input_s"
    input_gate_layer = name + "_input_gate"
    forget_gate_layer = name + "_forget_gate"
    state_layer = name + "_state"
    output_gate_layer = name + "_output_gate"
    state_atv_layer = name + "_state_atv"

    state_memory = Memory(name=state_layer, size=size)

    recurrent_bias = Bias(initial_std=0,
                          parameter_name=para_prefix + "_input_recurrent.b")
    memory_projection = FullMatrixProjection(
        out_memory, parameter_name=para_prefix + "_input_recurrent.w")
    Layer(
        name=recurrent_layer,
        type="mixed",
        size=size * 4,  # (input_s, input_gate, forget_gate, output_gate)
        error_clipping_threshold=error_clipping_threshold,
        bias=recurrent_bias,
        inputs=inputs + [memory_projection],
    )

    # Slice at offset 0: the candidate input.
    ExpressionLayer(
        name=input_s_layer,
        size=size,
        active_type=active_type,
        inputs=[IdentityOffsetProjection(recurrent_layer, offset=0)],
    )

    # Slice at offset size, plus a DotMulProjection of the previous state.
    ExpressionLayer(
        name=input_gate_layer,
        active_type=gate_active_type,
        inputs=[
            IdentityOffsetProjection(recurrent_layer, offset=size),
            DotMulProjection(state_memory,
                             parameter_name=para_prefix + "_input_check.w"),
        ],
    )

    # Slice at offset size*2, plus a DotMulProjection of the previous state.
    ExpressionLayer(
        name=forget_gate_layer,
        active_type=gate_active_type,
        inputs=[
            IdentityOffsetProjection(recurrent_layer, offset=size*2),
            DotMulProjection(state_memory,
                             parameter_name=para_prefix + "_forget_check.w"),
        ],
    )

    # New state from (input_s . input_gate) and (previous state . forget_gate).
    ExpressionLayer(
        name=state_layer,
        inputs=[
            DotMulOperator([input_s_layer, input_gate_layer]),
            DotMulOperator([state_memory, forget_gate_layer]),
        ],
    )

    # Slice at offset size*3; unlike the other gates this one reads the
    # state layer declared just above, not the state memory.
    ExpressionLayer(
        name=output_gate_layer,
        active_type=gate_active_type,
        inputs=[
            IdentityOffsetProjection(recurrent_layer, offset=size*3),
            DotMulProjection(state_layer,
                             parameter_name=para_prefix + "_output_check.w"),
        ],
    )

    ExpressionLayer(
        name=state_atv_layer,
        active_type=state_active_type,
        inputs=IdentityProjection(state_layer),
    )

    # Output: activated state multiplied element-wise by the output gate.
    ExpressionLayer(
        name=name,
        inputs=DotMulOperator([state_atv_layer, output_gate_layer]),
    )

# Like LstmRecurrentUnit, but wrapped in a recurrent layer group.
# It is equivalent to LstmLayer.
def LstmRecurrentLayerGroup(name, size, 
                            active_type, state_active_type, gate_active_type, 
                            inputs, para_prefix = None,
                            error_clipping_threshold = 0,
                            seq_reversed = False):
    """Declare an LstmRecurrentUnit wrapped in its own recurrent layer group.

    *inputs* are first gathered in a bias-free mixed layer
    "<name>_transform_input" of width ``size * 4``, declared outside the
    group. The group "<name>_layer_group" takes that layer as its only
    in-link and exposes *name* as its only out-link.

    Args:
        name: name of the LSTM output layer and prefix of the helper names.
        size: width of the LSTM unit.
        active_type, state_active_type, gate_active_type: forwarded to
            LstmRecurrentUnit unchanged.
        inputs: inputs of the transform layer.
        para_prefix: forwarded to LstmRecurrentUnit.
        error_clipping_threshold: forwarded to LstmRecurrentUnit.
        seq_reversed: forwarded to RecurrentLayerGroupBegin.
    """
    group_name = name + "_layer_group"
    transform_layer = name + "_transform_input"

    # Input transform, declared before the group is opened.
    Layer(
        name=transform_layer,
        type="mixed",
        size=size * 4,
        active_type="",
        bias=False,
        inputs=inputs,
    )

    RecurrentLayerGroupBegin(
        group_name,
        in_links=[transform_layer],
        out_links=[name],
        seq_reversed=seq_reversed,
    )

    # Inside the group the transformed input is passed through as-is.
    LstmRecurrentUnit(
        name=name,
        size=size,
        active_type=active_type,
        state_active_type=state_active_type,
        gate_active_type=gate_active_type,
        inputs=[IdentityProjection(transform_layer)],
        para_prefix=para_prefix,
        error_clipping_threshold=error_clipping_threshold,
    )

    RecurrentLayerGroupEnd(group_name)



# Gated recurrent unit (GRU); can be used in a recurrent machine.
# *inputs* should be a list of Projections, for example:
#   inputs = [FullMatrixProjection("input_layer_name")],
# *para_prefix* defines the parameter names; if two GatedRecurrentUnits have
#   the same *para_prefix*, they share the same parameters.
# *out_memory* can be created by the caller and passed in if it is also
#   used outside this unit.

def GatedRecurrentUnit(name, size, 
                       active_type, gate_active_type, 
                       inputs, para_prefix = None, 
                       error_clipping_threshold = 0,
                       out_memory = None):
    """Declare one GRU step built around GruStepLayer.

    Args:
        name: name of the step layer (and of the default output memory).
        size: width of the step layer; the input transform is ``size * 3``
            wide.
        active_type: passed to GruStepLayer as ``active_type``.
        gate_active_type: passed to GruStepLayer as ``active_gate_type``.
        inputs: either a list of projections, which are gathered in a
            bias-free mixed layer "<name>_transform_input", or a string
            naming a layer that already holds the transformed input (the
            string form is only used by GatedRecurrentLayerGroup).
        para_prefix: prefix of the parameter names; defaults to *name*.
        error_clipping_threshold: accepted for signature parity with the
            other units, but not read anywhere in this function.
        out_memory: memory fed back into the step layer; when None,
            ``Memory(name, size)`` is created here.
    """
    # isinstance rather than an exact type comparison, so that str
    # subclasses are also recognised as layer names.
    if isinstance(inputs, str): #only used by GatedRecurrentLayerGroup
        input_layer_name = inputs
    else:
        input_layer_name = name + "_" + "transform_input"
        Layer(
            name = input_layer_name,
            type = "mixed",
            size = size * 3,
            active_type = "",
            bias = False,
            inputs = inputs,
        )

    if para_prefix is None: 
        para_prefix = name
    if out_memory is None:
        out_memory = Memory(name = name, size = size)

    GruStepLayer(
          name = name,
          size = size,
          bias = Bias(parameter_name = para_prefix + "_gate.b"),
          inputs = [input_layer_name,
                    Input(out_memory, parameter_name = para_prefix + "_gate.w")],
          active_type = active_type,
          active_gate_type = gate_active_type,
    )

def GatedRecurrentUnitNaive(name, size, 
                            active_type, gate_active_type, 
                            inputs, para_prefix = None, 
                            error_clipping_threshold = 0,
                            out_memory = None):
    """Declare one GRU step spelled out with mixed/expression layers.

    Takes the same arguments as GatedRecurrentUnit, but instead of a single
    GruStepLayer it declares "<name>_update_gate", "<name>_reset_gate",
    "<name>_reset_output", "<name>_output_candidate" and finally "<name>".

    Args:
        name: name of the output layer and prefix of the helper layers.
        size: width of every declared layer; the input transform is
            ``size * 3`` wide and is read in three slices of *size*.
        active_type: activation of the output candidate.
        gate_active_type: activation of the update and reset gates.
        inputs: either a list of projections, which are gathered in a
            bias-free mixed layer "<name>_transform_input", or a string
            naming a layer that already holds the transformed input.
        para_prefix: prefix of the parameter names; defaults to *name*.
        error_clipping_threshold: forwarded to the two gate layers and to
            the output candidate layer.
        out_memory: memory of the previous output; when None,
            ``Memory(name, size)`` is created here.
    """
    # isinstance rather than an exact type comparison, so that str
    # subclasses are also recognised as layer names.
    if isinstance(inputs, str): #only used by GatedRecurrentLayerGroup
        input_layer_name = inputs
    else:
        input_layer_name = name + "_" + "transform_input"
        Layer(
            name = input_layer_name,
            type = "mixed",
            size = size * 3,
            active_type = "",
            bias = False,
            inputs = inputs,
        )

    if para_prefix is None: 
        para_prefix = name
    if out_memory is None:
        out_memory = Memory(name = name, size = size)

    # Update gate: slice at offset 0 of the transformed input plus a full
    # projection of the previous output.
    Layer(
          name = name + "_" + "update_gate",
          type = "mixed",
          size = size, 
          active_type = gate_active_type,
          error_clipping_threshold = error_clipping_threshold,
          bias = Bias(initial_std = 0, parameter_name = para_prefix + "_update_gate.b"),
          inputs = [IdentityOffsetProjection(input_layer_name, offset=0),
                    FullMatrixProjection(out_memory,
                                         parameter_name = para_prefix + "_update_gate.w")],
    )
    # Reset gate: slice at offset size plus a full projection of the
    # previous output.
    Layer(
          name = name + "_" + "reset_gate",
          type = "mixed",
          size = size, 
          active_type = gate_active_type,
          error_clipping_threshold = error_clipping_threshold,
          bias = Bias(initial_std = 0, parameter_name = para_prefix + "_reset_gate.b"),
          inputs = [IdentityOffsetProjection(input_layer_name, offset=size),
                    FullMatrixProjection(out_memory,
                                         parameter_name = para_prefix + "_reset_gate.w")],
    )
    # Previous output multiplied element-wise by the reset gate.
    ExpressionLayer(
          name = name + "_" + "reset_output",
          inputs = DotMulOperator([out_memory, name + "_" + "reset_gate"]),
    )
    # Candidate: slice at offset size*2 plus a full projection of the
    # reset-gated previous output.
    Layer(
          name = name + "_" + "output_candidate",
          type = "mixed",
          size = size, 
          active_type = active_type,
          error_clipping_threshold = error_clipping_threshold,
          bias = Bias(initial_std = 0, parameter_name = para_prefix + "_output_candidate.b"),
          inputs = [IdentityOffsetProjection(input_layer_name, offset=size*2),
                    FullMatrixProjection(name + "_" + "reset_output",
                                         parameter_name = para_prefix + "_output_candidate.w")],
    )
    # Element-wise interpolation between the previous output and the
    # candidate. The three terms are memory, -(memory . update_gate) and
    # (candidate . update_gate); presumably the expression layer sums them
    # -- TODO confirm against the mixed-layer semantics.
    ExpressionLayer( #element-wise interpolation
          name = name,
          inputs = [IdentityProjection(out_memory),
                    DotMulOperator([out_memory, 
                                    name + "_" + "update_gate"], scale=-1.0),
                    DotMulOperator([name + "_" + "output_candidate", 
                                    name + "_" + "update_gate"]),
                    ],
    )

# Like GatedRecurrentUnit, but wrapped in a recurrent layer group.
# It is equivalent to GatedRecurrentLayer.
def GatedRecurrentLayerGroup(name, size, 
                             active_type, gate_active_type, 
                             inputs, para_prefix = None,
                             error_clipping_threshold = 0,
                             seq_reversed = False):
    """Declare a GatedRecurrentUnit wrapped in its own recurrent layer group.

    *inputs* are first gathered in a bias-free mixed layer
    "<name>_transform_input" of width ``size * 3``, declared outside the
    group. The group "<name>_layer_group" takes that layer as its only
    in-link and exposes *name* as its only out-link.

    Args:
        name: name of the GRU output layer and prefix of the helper names.
        size: width of the GRU unit.
        active_type, gate_active_type: forwarded to GatedRecurrentUnit
            unchanged.
        inputs: inputs of the transform layer.
        para_prefix: forwarded to GatedRecurrentUnit.
        error_clipping_threshold: forwarded to GatedRecurrentUnit.
        seq_reversed: forwarded to RecurrentLayerGroupBegin.
    """
    group_name = name + "_layer_group"
    transform_layer = name + "_transform_input"

    # Input transform, declared before the group is opened.
    Layer(
        name=transform_layer,
        type="mixed",
        size=size * 3,
        active_type="",
        bias=False,
        inputs=inputs,
    )

    RecurrentLayerGroupBegin(
        group_name,
        in_links=[transform_layer],
        out_links=[name],
        seq_reversed=seq_reversed,
    )

    # The unit receives the layer *name* (a string), which tells it that the
    # input has already been transformed outside the group.
    GatedRecurrentUnit(
        name=name,
        size=size,
        active_type=active_type,
        gate_active_type=gate_active_type,
        inputs=transform_layer,
        para_prefix=para_prefix,
        error_clipping_threshold=error_clipping_threshold,
    )

    RecurrentLayerGroupEnd(group_name)

